{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zero-shot Text Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although SetFit was designed for few-shot learning, the method can also be applied in scenarios where no labeled data is available. The main trick is to create _synthetic examples_ that resemble the classification task, and then train a SetFit model on them. \n",
    "\n",
    "Remarkably, this simple technique typically outperforms the zero-shot pipeline in 🤗 Transformers, and can generate predictions 5x (or more) faster!\n",
    "\n",
    "In this tutorial, we'll explore how:\n",
    "\n",
    "* SetFit can be applied to zero-shot classification.\n",
    "* Adding synthetic examples can also boost the performance of few-shot classification."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you're running this notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install setfit matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To benchmark the performance of the \"zero-shot\" method, we'll use the following dataset and pretrained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_id = \"emotion\"\n",
    "model_id = \"sentence-transformers/paraphrase-mpnet-base-v2\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll download the reference dataset from the Hugging Face Hub:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "reference_dataset = load_dataset(dataset_id)\n",
    "reference_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "DatasetDict({\n",
    "    train: Dataset({\n",
    "        features: ['text', 'label'],\n",
    "        num_rows: 16000\n",
    "    })\n",
    "    validation: Dataset({\n",
    "        features: ['text', 'label'],\n",
    "        num_rows: 2000\n",
    "    })\n",
    "    test: Dataset({\n",
    "        features: ['text', 'label'],\n",
    "        num_rows: 2000\n",
    "    })\n",
    "})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we're set up, let's create some synthetic data to train on!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a synthetic dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first thing we need to do is create a dataset of synthetic examples. In `setfit`, we can do this by applying the `get_templated_dataset()` function to a dummy dataset. This function expects a few main things:\n",
    "\n",
    "* A list of candidate labels to classify with. We'll use the labels from the reference dataset here, but this could be anything that's relevant to the task and dataset at hand.\n",
    "* A template to generate examples with. By default, it is `\"This sentence is {}\"`, where the `{}` is filled with one of the candidate labels.\n",
    "* A sample size $N$, which will create $N$ synthetic examples per class. We find $N=8$ usually works best.\n",
    "\n",
    "Armed with this information, let's first extract some candidate labels from the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract ClassLabel feature from \"label\" column\n",
    "label_features = reference_dataset[\"train\"].features[\"label\"]\n",
    "# Label names to classify with\n",
    "candidate_labels = label_features.names\n",
    "candidate_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']\n",
    "```\n",
    "\n",
    "<Tip>\n",
    "\n",
    "Some datasets on the Hugging Face Hub don't have a `ClassLabel` feature for the label column. In these cases, you should compute the candidate labels manually by first computing the id2label mapping as follows:\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_id2label(dataset):\n",
    "    # Requires a \"label_text\" column holding the label names alongside the \"label\" IDs\n",
    "    label_names = dataset.unique(\"label_text\")\n",
    "    # The column with the label IDs\n",
    "    label_ids = dataset.unique(\"label\")\n",
    "    id2label = dict(zip(label_ids, label_names))\n",
    "    # Sort by label ID\n",
    "    return dict(sorted(id2label.items()))\n",
    "\n",
    "id2label = get_id2label(reference_dataset[\"train\"])\n",
    "candidate_labels = list(id2label.values())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the labels, it's a simple matter to create synthetic examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "from setfit import get_templated_dataset\n",
    "\n",
    "# A dummy dataset to fill with synthetic examples\n",
    "dummy_dataset = Dataset.from_dict({})\n",
    "train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "Dataset({\n",
    "    features: ['text', 'label'],\n",
    "    num_rows: 48\n",
    "})\n",
    "```\n",
    "\n",
    "<Tip>\n",
    "\n",
    "You might find you can get better performance by tweaking the `template` argument from the default of `\"This sentence is {}\"` to variants like `\"This example is {}\"`.\n",
    "\n",
    "</Tip>\n",
    "\n",
    "\n",
    "Since our dataset has 6 classes and we chose a sample size of 8, our synthetic dataset contains $6\\times 8=48$ examples. Let's take a look at a few of them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset.shuffle()[:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "{'text': ['This sentence is love',\n",
    "  'This sentence is fear',\n",
    "  'This sentence is joy'],\n",
    " 'label': [2, 4, 1]}\n",
    "```\n",
    "\n",
    "We can see that each input takes the form of the template and has a corresponding label associated with it. \n",
    "\n",
    "Let's now train a SetFit model on these examples!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tuning the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. We can do so by using the `SetFitModel.from_pretrained()` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from setfit import SetFitModel\n",
    "\n",
    "model = SetFitModel.from_pretrained(model_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we've downloaded a pretrained Sentence Transformer from the Hub and added a logistic classification head to create the SetFit model. As indicated in the message, we need to train this model on some labeled examples. We can do so by using the [Trainer](https://huggingface.co/docs/setfit/main/en/reference/trainer#setfit.Trainer) class as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from setfit import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=reference_dataset[\"test\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've created a trainer, we can train it! While we're at it, let's time how long it takes to train and evaluate the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "trainer.train()\n",
    "zeroshot_metrics = trainer.evaluate()\n",
    "zeroshot_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "***** Running training *****\n",
    "  Num examples = 1920\n",
    "  Num epochs = 1\n",
    "  Total optimization steps = 120\n",
    "  Total train batch size = 16\n",
    "***** Running evaluation *****\n",
    "{'accuracy': 0.5345}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s\n",
    "Wall time: 11 s\n",
    "```\n",
    "\n",
    "Great, now that we have a reference score, let's compare it against the zero-shot pipeline from 🤗 Transformers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparing against the zero-shot pipeline from 🤗 Transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "🤗 Transformers provides a zero-shot pipeline that frames text classification as a natural language inference task. Let's load the pipeline and place it on the GPU for fast inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "pipe = pipeline(\"zero-shot-classification\", device=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the model, let's generate some predictions. We'll use the same candidate labels as we did with SetFit and increase the batch size to speed things up:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "zeroshot_preds = pipe(reference_dataset[\"test\"][\"text\"], batch_size=16, candidate_labels=candidate_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s\n",
    "Wall time: 53.1 s\n",
    "```\n",
    "\n",
    "Note that generating predictions took almost 5x longer than training and evaluating SetFit (53.1 s vs. 11 s wall time)! OK, so how well does it perform? Each prediction is a dictionary with the label names ranked by score:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zeroshot_preds[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "{'sequence': 'im feeling rather rotten so im not very ambitious right now',\n",
    " 'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],\n",
    " 'scores': [0.7367985844612122,\n",
    "  0.10041674226522446,\n",
    "  0.09770156443119049,\n",
    "  0.05880110710859299,\n",
    "  0.004266355652362108,\n",
    "  0.0020156768150627613]}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can take the top-ranked label name from each prediction and use the `str2int()` method of the `ClassLabel` feature of the `label` column to convert it to an integer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = [label_features.str2int(pred[\"labels\"][0]) for pred in zeroshot_preds]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** As noted earlier, if you're using a dataset that doesn't have a `ClassLabel` feature for the label column, you'll need to compute the label mapping manually with something like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id2label = get_id2label(reference_dataset[\"train\"])\n",
    "label2id = {v: k for k, v in id2label.items()}\n",
    "preds = [label2id[pred[\"labels\"][0]] for pred in zeroshot_preds]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last step is to compute accuracy using 🤗 Evaluate:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "transformers_metrics = metric.compute(predictions=preds, references=reference_dataset[\"test\"][\"label\"])\n",
    "transformers_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "{'accuracy': 0.3765}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compared to SetFit, this approach performs significantly worse (an accuracy of 0.3765 vs. 0.5345). Let's wrap up our analysis by combining synthetic examples with a few labeled ones."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Augmenting labeled data with synthetic examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you have a few labeled examples, adding synthetic data can often boost performance. To simulate this, let's first sample 8 labeled examples per class from our reference dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from setfit import sample_dataset\n",
    "\n",
    "# Sample 8 labeled examples per class (6 classes x 8 = 48 rows)\n",
    "train_dataset = sample_dataset(reference_dataset[\"train\"], num_samples=8)\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "Dataset({\n",
    "    features: ['text', 'label'],\n",
    "    num_rows: 48\n",
    "})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To warm up, we'll train a SetFit model on these true labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SetFitModel.from_pretrained(model_id)\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=reference_dataset[\"test\"]\n",
    ")\n",
    "trainer.train()\n",
    "fewshot_metrics = trainer.evaluate()\n",
    "fewshot_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "{'accuracy': 0.4705}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that for this particular dataset, the performance with true labels is _worse_ than training on synthetic examples! In our experiments, we found that the difference depends strongly on the dataset in question. Since SetFit models are fast to train, you can always try both approaches and pick the best one.\n",
    "\n",
    "In any case, let's now add some synthetic examples to our training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)\n",
    "augmented_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "Dataset({\n",
    "    features: ['text', 'label'],\n",
    "    num_rows: 96\n",
    "})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As before, we can train and evaluate SetFit with the augmented dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SetFitModel.from_pretrained(model_id)\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataset=augmented_dataset,\n",
    "    eval_dataset=reference_dataset[\"test\"]\n",
    ")\n",
    "trainer.train()\n",
    "augmented_metrics = trainer.evaluate()\n",
    "augmented_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "{'accuracy': 0.613}\n",
    "```\n",
    "\n",
    "Great, this has given us a significant boost over training on the labeled examples alone (an accuracy of 0.613 vs. 0.4705), and a gain of about 8 percentage points over the purely synthetic approach (0.5345).\n",
    "\n",
    "Let's plot the final results for comparison:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "methods = [\"Transformers (zero-shot)\", \"SetFit (zero-shot)\", \"SetFit (augmented)\"]\n",
    "accuracies = [m[\"accuracy\"] for m in (transformers_metrics, zeroshot_metrics, augmented_metrics)]\n",
    "df = pd.DataFrame({\"Method\": methods, \"Accuracy\": accuracies})\n",
    "df.plot(kind=\"barh\", x=\"Method\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![setfit_zero_shot_results](https://github.com/huggingface/setfit/assets/37621491/b02d3e62-d51c-4506-91f6-2fe9b7ef554d)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
